We ensure that the sampling function is explicitly called during the image generation process after obtaining z_mean and z_log_var.¶
Generate different images each time by explicitly calling the sampling function with new random noise.
In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
The process of sampling the latent vector involves generating a random sample from the distribution defined by z_mean and z_log_var. This is done using the reparameterization trick, which allows gradients to flow through the sampling process during training.
In [2]:
# Sampling function with randomness
def sampling(z_mean, z_log_var):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)).

    Samples epsilon ~ N(0, I) with the same (batch, latent_dim) shape as
    z_mean and returns z_mean + sigma * epsilon, keeping the sampling step
    differentiable with respect to z_mean and z_log_var.
    """
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    # tf.random.normal is the current API; tf.keras.backend.random_normal
    # is deprecated in modern TensorFlow releases.
    epsilon = tf.random.normal(shape=(batch, dim))
    # exp(0.5 * log_var) == standard deviation sigma
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon
In [3]:
# VAE Class with custom call method
class VAE(keras.Model):
    """Variational autoencoder wrapping a given encoder and decoder.

    The total loss (scaled binary cross-entropy reconstruction + KL
    divergence) is registered via ``add_loss()`` inside ``call()``, so
    ``compile()`` needs no explicit ``loss`` argument.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        # Encoder returns posterior parameters and a sampled latent z.
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Mean per-pixel BCE, then scaled by the image size (200x200x3)
        # so the reconstruction term behaves like a sum over pixels.
        reconstruction_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(inputs, reconstructed)
        )
        reconstruction_loss *= 200 * 200 * 3
        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I).
        # NOTE(review): reduce_mean averages over the latent axis too;
        # the textbook form sums over latent dims before the batch mean.
        # Kept as-is to preserve the trained behavior shown below.
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_mean(kl_loss)
        kl_loss *= -0.5
        self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed
In [ ]:
In [4]:
# Encoder
# Maps a 200x200 RGB image to the parameters (z_mean, z_log_var) of a
# 2-D latent Gaussian plus one sampled latent vector z.
latent_dim = 2
encoder_inputs = keras.Input(shape=(200, 200, 3))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# Lambda wraps the sampling() function so the reparameterized draw is
# part of the graph and gradients flow through z_mean / z_log_var.
z = layers.Lambda(lambda args: sampling(*args), output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
Model: "encoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 200, 200, 3) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ conv2d (Conv2D) │ (None, 100, 100, 32) │ 896 │ input_layer[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ conv2d_1 (Conv2D) │ (None, 50, 50, 64) │ 18,496 │ conv2d[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ flatten (Flatten) │ (None, 160000) │ 0 │ conv2d_1[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ dense (Dense) │ (None, 16) │ 2,560,016 │ flatten[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z_mean (Dense) │ (None, 2) │ 34 │ dense[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z_log_var (Dense) │ (None, 2) │ 34 │ dense[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z (Lambda) │ (None, 2) │ 0 │ z_mean[0][0], │ │ │ │ │ z_log_var[0][0] │ └───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
Total params: 2,579,476 (9.84 MB)
Trainable params: 2,579,476 (9.84 MB)
Non-trainable params: 0 (0.00 B)
In [ ]:
In [5]:
# Decoder
# Mirrors the encoder: expands a latent vector back to a 200x200x3 image
# with sigmoid output so pixel values land in [0, 1].
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(50 * 50 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((50, 50, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()
Model: "decoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ input_layer_1 (InputLayer) │ (None, 2) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_1 (Dense) │ (None, 160000) │ 480,000 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ reshape (Reshape) │ (None, 50, 50, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose (Conv2DTranspose) │ (None, 100, 100, 64) │ 36,928 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_1 (Conv2DTranspose) │ (None, 200, 200, 32) │ 18,464 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_2 (Conv2DTranspose) │ (None, 200, 200, 3) │ 867 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 536,259 (2.05 MB)
Trainable params: 536,259 (2.05 MB)
Non-trainable params: 0 (0.00 B)
In [ ]:
In [6]:
# VAE Model
vae = VAE(encoder, decoder)
# No loss argument: the loss is added inside VAE.call() via add_loss().
vae.compile(optimizer='adam')
vae.build(input_shape=(None, 200, 200, 3))
vae.summary()
# Prepare and normalize the image (single training example, scaled to [0, 1])
pic_1 = keras.preprocessing.image.load_img('pic_1.jpeg', target_size=(200, 200))
pic_1 = keras.preprocessing.image.img_to_array(pic_1).astype("float32") / 255
pic_1 = np.expand_dims(pic_1, 0)
# Train the VAE model (overfits to the one image by design)
vae.fit(pic_1, epochs=100, batch_size=1)
Model: "vae"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ encoder (Functional) │ ? │ 2,579,476 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ decoder (Functional) │ ? │ 536,259 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 3,115,735 (11.89 MB)
Trainable params: 3,115,735 (11.89 MB)
Non-trainable params: 0 (0.00 B)
Epoch 1/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - loss: 83177.3750 Epoch 2/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 83199.9844 Epoch 3/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 83137.7109 Epoch 4/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 83081.3359 Epoch 5/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 83006.4609 Epoch 6/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 82808.2734 Epoch 7/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 82488.0391 Epoch 8/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 82183.8203 Epoch 9/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 81755.2578 Epoch 10/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 80598.4375 Epoch 11/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 79338.1484 Epoch 12/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 78044.2891 Epoch 13/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 76914.7656 Epoch 14/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 75710.9375 Epoch 15/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 74347.7656 Epoch 16/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 73219.6562 Epoch 17/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - loss: 72456.3281 Epoch 18/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 71910.6094 Epoch 19/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 71444.3828 Epoch 20/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 71049.1016 Epoch 21/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 70710.1562 Epoch 22/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 70397.3125 Epoch 23/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 70146.9375 Epoch 24/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 69991.2422 Epoch 25/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 78ms/step - loss: 69869.0469 Epoch 26/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 69733.7891 Epoch 27/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 69622.5547 Epoch 28/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 69523.6719 Epoch 29/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step 
- loss: 69410.1484 Epoch 30/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 69317.9375 Epoch 31/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 88ms/step - loss: 69252.4141 Epoch 32/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 69166.5156 Epoch 33/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 69107.1328 Epoch 34/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 69046.4141 Epoch 35/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 68993.1875 Epoch 36/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 68969.3984 Epoch 37/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 68947.7969 Epoch 38/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 68869.1641 Epoch 39/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 68868.7969 Epoch 40/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 68825.5625 Epoch 41/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 68771.5547 Epoch 42/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 68728.8594 Epoch 43/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 68762.8047 Epoch 44/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 68681.9766 Epoch 45/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 68623.1094 Epoch 46/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step - loss: 68669.6094 Epoch 47/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 68528.6094 Epoch 48/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 68514.2266 Epoch 49/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 68489.3438 Epoch 50/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 68406.0469 Epoch 51/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 68353.9297 Epoch 52/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 68332.2188 Epoch 53/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 68261.9141 Epoch 54/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 68183.1953 Epoch 55/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 68140.0703 Epoch 56/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 68095.0469 Epoch 57/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 67997.8047 Epoch 58/100 1/1 
━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 67951.8047 Epoch 59/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 67896.7266 Epoch 60/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 67806.6484 Epoch 61/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 67762.4375 Epoch 62/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 67698.6406 Epoch 63/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 67666.8828 Epoch 64/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - loss: 67597.9141 Epoch 65/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 67606.9609 Epoch 66/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 67540.2578 Epoch 67/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 67496.0938 Epoch 68/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 67560.3906 Epoch 69/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 67456.5078 Epoch 70/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 67455.4766 Epoch 71/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 67439.9688 Epoch 72/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 67391.3828 Epoch 73/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 67367.6094 Epoch 74/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 67405.2031 Epoch 75/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 67329.3984 Epoch 76/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 67312.9844 Epoch 77/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 67310.9297 Epoch 78/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 101ms/step - loss: 67284.5234 Epoch 79/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 67261.4219 Epoch 80/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 67256.9219 Epoch 81/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 67239.6406 Epoch 82/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 67215.1641 Epoch 83/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 67207.1094 Epoch 84/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 67196.3984 Epoch 85/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step - loss: 67181.4922 Epoch 86/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - 
loss: 67167.9766 Epoch 87/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 67180.4688 Epoch 88/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 67152.2734 Epoch 89/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 67145.4922 Epoch 90/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 67135.4688 Epoch 91/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 67125.2109 Epoch 92/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 67115.7500 Epoch 93/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 67105.8438 Epoch 94/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 67106.8203 Epoch 95/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - loss: 67091.8750 Epoch 96/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 67082.9375 Epoch 97/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 67074.0703 Epoch 98/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 67084.1484 Epoch 99/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 67063.2578 Epoch 100/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 67076.4297
Out[6]:
<keras.src.callbacks.history.History at 0x1c974460b50>
In [ ]:
In [7]:
# Show plot and save image function
def show_and_save_plot(image, save_path):
    """Display *image* with matplotlib and save the figure to *save_path*."""
    plt.imshow(image.squeeze())
    plt.axis('off')
    plt.savefig(save_path)
    plt.show()

# The encoder pass is deterministic for fixed weights, so it is hoisted
# out of the loop; only the sampling() draw below changes per iteration.
z_mean, z_log_var, _ = vae.encoder.predict(pic_1)
for i in range(10):
    # Introduce randomness in sampling: fresh epsilon each iteration
    encoded_imgs = sampling(z_mean, z_log_var).numpy()
    decoded_imgs = vae.decoder.predict(encoded_imgs)
    save_path = f'generated_image_{i + 1}.png'
    show_and_save_plot(decoded_imgs, save_path)
print("Images have been saved and displayed successfully.")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 93ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 34ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 34ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
Images have been saved and displayed successfully.
In [ ]:
In [ ]:
What output do you see? Explain why¶
For the code above: when the model generates 10 black-and-white images, their colors can differ from one another, but when it generates 10 colorful images, all 10 look the same. Why?¶
In [8]:
# Inspect the learned posterior parameters for the training image.
# Very negative z_log_var means tiny sampling variance, so draws of z
# (and hence decoded images) barely differ between iterations.
z_mean, z_log_var, _ = vae.encoder.predict(pic_1)
print("z_mean:", z_mean)
print("z_log_var:", z_log_var)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step z_mean: [[-12.294494 -5.9670925]] z_log_var: [[-10.742041 -2.4054945]]
In [9]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
# Sampling function with randomness
def sampling(z_mean, z_log_var):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)).

    Samples epsilon ~ N(0, I) with the same (batch, latent_dim) shape as
    z_mean and returns z_mean + sigma * epsilon, keeping the sampling
    step differentiable with respect to z_mean and z_log_var.
    """
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    # tf.random.normal is the current API; tf.keras.backend.random_normal
    # is deprecated in modern TensorFlow releases.
    epsilon = tf.random.normal(shape=(batch, dim))
    # exp(0.5 * log_var) == standard deviation sigma
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon
# VAE Class with custom call method
class VAE(keras.Model):
    """VAE variant that uses mean-squared-error reconstruction loss.

    The total loss (scaled MSE reconstruction + KL divergence) is
    registered via ``add_loss()`` inside ``call()``, so ``compile()``
    needs no explicit ``loss`` argument.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.mse = tf.keras.losses.MeanSquaredError()

    def call(self, inputs):
        # Encoder returns posterior parameters and a sampled latent z.
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # MeanSquaredError already reduces to a scalar mean; the outer
        # reduce_mean is a no-op kept for parity with the BCE version.
        reconstruction_loss = tf.reduce_mean(
            self.mse(inputs, reconstructed)
        )
        reconstruction_loss *= 200 * 200 * 3
        # KL divergence between N(z_mean, exp(z_log_var)) and N(0, I).
        # NOTE(review): reduce_mean also averages over the latent axis;
        # the textbook form sums over latent dims first. Kept as-is.
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_mean(kl_loss)
        kl_loss *= -0.5
        self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed
# Encoder
# Maps a 200x200 RGB image to a 10-D latent Gaussian plus a sampled z.
latent_dim = 10  # Increased latent dimension
encoder_inputs = keras.Input(shape=(200, 200, 3))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# Lambda wraps sampling() so the reparameterized draw is part of the graph.
z = layers.Lambda(lambda args: sampling(*args), output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

# Decoder
# Expands a latent vector back to a 200x200x3 image in [0, 1].
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(50 * 50 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((50, 50, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()

# VAE Model (loss is registered inside VAE.call via add_loss)
vae = VAE(encoder, decoder)
vae.compile(optimizer='adam')
vae.build(input_shape=(None, 200, 200, 3))
vae.summary()

# Prepare and normalize the image (single training example in [0, 1])
pic_1 = keras.preprocessing.image.load_img('pic_1.jpeg', target_size=(200, 200))
pic_1 = keras.preprocessing.image.img_to_array(pic_1).astype("float32") / 255
pic_1 = np.expand_dims(pic_1, 0)

# Train the VAE model
vae.fit(pic_1, epochs=100, batch_size=1)

# Show plot and save image function
def show_and_save_plot(image, save_path):
    """Display *image* with matplotlib and save the figure to *save_path*."""
    plt.imshow(image.squeeze())
    plt.axis('off')
    plt.savefig(save_path)
    plt.show()

# Run the encoder-decoder multiple times to generate different copies.
# The encoder pass is deterministic for fixed weights, so it is hoisted
# out of the loop; only the sampling() draw changes per iteration.
z_mean, z_log_var, _ = vae.encoder.predict(pic_1)
for i in range(10):
    # Introduce randomness in sampling: fresh epsilon each iteration
    encoded_imgs = sampling(z_mean, z_log_var).numpy()
    decoded_imgs = vae.decoder.predict(encoded_imgs)
    save_path = f'generated_image_{i + 1}.png'
    show_and_save_plot(decoded_imgs, save_path)
print("Images have been saved and displayed successfully.")
Model: "encoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_2 (InputLayer) │ (None, 200, 200, 3) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ conv2d_2 (Conv2D) │ (None, 100, 100, 32) │ 896 │ input_layer_2[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ conv2d_3 (Conv2D) │ (None, 50, 50, 64) │ 18,496 │ conv2d_2[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ flatten_1 (Flatten) │ (None, 160000) │ 0 │ conv2d_3[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ dense_2 (Dense) │ (None, 16) │ 2,560,016 │ flatten_1[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z_mean (Dense) │ (None, 10) │ 170 │ dense_2[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z_log_var (Dense) │ (None, 10) │ 170 │ dense_2[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z (Lambda) │ (None, 10) │ 0 │ z_mean[0][0], │ │ │ │ │ z_log_var[0][0] │ └───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
Total params: 2,579,748 (9.84 MB)
Trainable params: 2,579,748 (9.84 MB)
Non-trainable params: 0 (0.00 B)
Model: "decoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ input_layer_3 (InputLayer) │ (None, 10) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_3 (Dense) │ (None, 160000) │ 1,760,000 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ reshape_1 (Reshape) │ (None, 50, 50, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_3 (Conv2DTranspose) │ (None, 100, 100, 64) │ 36,928 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_4 (Conv2DTranspose) │ (None, 200, 200, 32) │ 18,464 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_5 (Conv2DTranspose) │ (None, 200, 200, 3) │ 867 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 1,816,259 (6.93 MB)
Trainable params: 1,816,259 (6.93 MB)
Non-trainable params: 0 (0.00 B)
Model: "vae_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ encoder (Functional) │ ? │ 2,579,748 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ decoder (Functional) │ ? │ 1,816,259 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 4,396,007 (16.77 MB)
Trainable params: 4,396,007 (16.77 MB)
Non-trainable params: 0 (0.00 B)
Epoch 1/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - loss: 7610.7393 Epoch 2/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 7726.1265 Epoch 3/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 7600.4229 Epoch 4/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 7594.1895 Epoch 5/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 7586.2324 Epoch 6/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 75ms/step - loss: 7585.9980 Epoch 7/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 7580.5942 Epoch 8/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - loss: 7562.7554 Epoch 9/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 7562.2412 Epoch 10/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 7558.2412 Epoch 11/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 7538.1782 Epoch 12/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 83ms/step - loss: 7530.4229 Epoch 13/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 77ms/step - loss: 7481.6230 Epoch 14/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 108ms/step - loss: 7464.0249 Epoch 15/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 92ms/step - loss: 7488.7236 Epoch 16/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 79ms/step - loss: 7442.9277 Epoch 17/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 7459.9717 Epoch 18/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 97ms/step - loss: 7406.3745 Epoch 19/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 172ms/step - loss: 7363.2588 Epoch 20/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 7301.2778 Epoch 21/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 80ms/step - loss: 7249.5308 Epoch 22/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 75ms/step - loss: 7515.9595 Epoch 23/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 92ms/step - loss: 7032.7549 Epoch 24/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 83ms/step - loss: 7185.9648 Epoch 25/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 78ms/step - loss: 6800.1313 Epoch 26/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 7498.6348 Epoch 27/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 87ms/step - loss: 7238.9546 Epoch 28/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 78ms/step - loss: 6345.5332 Epoch 29/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 77ms/step - loss: 6500.5693 Epoch 
30/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 6510.3496 Epoch 31/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 84ms/step - loss: 6107.6074 Epoch 32/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 90ms/step - loss: 5453.5703 Epoch 33/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 87ms/step - loss: 5963.7236 Epoch 34/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 86ms/step - loss: 5101.7266 Epoch 35/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 92ms/step - loss: 5312.6353 Epoch 36/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 81ms/step - loss: 5710.2451 Epoch 37/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 4759.2271 Epoch 38/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 4363.7920 Epoch 39/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 5299.1719 Epoch 40/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 4659.5791 Epoch 41/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 4278.9546 Epoch 42/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 4238.8711 Epoch 43/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 3939.6272 Epoch 44/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 4250.3989 Epoch 45/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 3802.0754 Epoch 46/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 4096.0186 Epoch 47/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 3790.0337 Epoch 48/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 84ms/step - loss: 4175.6968 Epoch 49/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 99ms/step - loss: 4180.9517 Epoch 50/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 4046.8005 Epoch 51/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 3911.4573 Epoch 52/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 3511.0881 Epoch 53/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 3465.3376 Epoch 54/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 3440.6035 Epoch 55/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 3532.6001 Epoch 56/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 3305.6157 Epoch 57/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 3245.9075 Epoch 58/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 3229.1716 
Epoch 59/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 3110.1326 Epoch 60/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 3692.0945 Epoch 61/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 2908.8545 Epoch 62/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 3116.0874 Epoch 63/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 2801.7007 Epoch 64/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 2918.8108 Epoch 65/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 2768.5195 Epoch 66/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 2687.1475 Epoch 67/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 2752.2695 Epoch 68/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 3081.9973 Epoch 69/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 2474.6143 Epoch 70/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 2219.4907 Epoch 71/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 2633.3442 Epoch 72/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 2011.7100 Epoch 73/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 1995.2678 Epoch 74/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 2160.3147 Epoch 75/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 2189.7275 Epoch 76/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 1904.3241 Epoch 77/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 1769.3903 Epoch 78/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 75ms/step - loss: 1916.2294 Epoch 79/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 2144.8218 Epoch 80/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 2341.8989 Epoch 81/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 1934.7032 Epoch 82/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 1598.4662 Epoch 83/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 2005.9113 Epoch 84/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - loss: 2034.1517 Epoch 85/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 75ms/step - loss: 3048.6072 Epoch 86/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - loss: 2552.8713 Epoch 87/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 
1493.8463 Epoch 88/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 1886.9950 Epoch 89/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 1611.8395 Epoch 90/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 2299.6582 Epoch 91/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 77ms/step - loss: 1591.7374 Epoch 92/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 1661.3323 Epoch 93/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 2869.6941 Epoch 94/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 1418.2002 Epoch 95/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 1554.6145 Epoch 96/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 1910.1365 Epoch 97/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 1320.2623 Epoch 98/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 1249.3489 Epoch 99/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 76ms/step - loss: 1830.5522 Epoch 100/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 92ms/step - loss: 1435.1040 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 91ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 85ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 36ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step
Images have been saved and displayed successfully.
In [ ]:
In [ ]:
In [ ]:
In [ ]: